Simple Examples
This tutorial goes through a few common ML tasks using the CREMI dataset and a 2D U-Net.
Introduction and overview
In this tutorial we will cover a few basic ML tasks using the DaCapo toolbox. We will:
Prepare a dataloader for the CREMI dataset
Train a simple 2D U-Net for both instance and semantic segmentation
Visualize the results
Environment setup
If you have not already done so, you will need to install DaCapo. You can do this by first creating a new environment and then installing the DaCapo Toolbox.
I highly recommend using uv for environment management, but there are many tools to choose from.
# Initialize a new uv-managed project in the current directory
uv init
# Add the DaCapo toolbox, installed directly from its GitHub repository, as a dependency
uv add git+https://github.com/pattonw/dacapo-toolbox.git
Data Preparation
DaCapo works with zarr, so we will download CREMI samples A (used for training) and C (used for testing) and save them to a zarr container.
import wget
from pathlib import Path
import dask

# Run dask computations on a single thread.
dask.config.set(scheduler="single-threaded")

# Download CREMI sample C (test) and sample A (train) as HDF5 files; the next
# cell converts them to zarr. Each file is checked on its own so that a file
# missing from an earlier, partial run is still fetched.
for sample in ("C", "A"):
    filename = f"sample_{sample}_20160501.hdf"
    if not Path(filename).exists():
        wget.download(f"https://cremi.org/static/data/{filename}", filename)
Data Loading
We will use the funlib.persistence library to interface with zarr. This library adds support for units, voxel size, and axis names along with the ability to query our data based on a Roi object describing a specific rectangular piece of data. This is especially useful in a microscopy context where you regularly need to chunk your data for processing.
import numpy as np
from funlib.persistence import prepare_ds, open_ds
import h5py
from pathlib import Path
import re


def _hdf_to_zarr(hdf_path, zarr_group):
    """Return ``(raw, labels)`` arrays for one CREMI sample stored under ``zarr_group``.

    If both ``{zarr_group}/raw`` and ``{zarr_group}/labels`` already exist they
    are simply opened. Otherwise both are (re)written from the HDF5 datasets
    ``volumes/raw`` and ``volumes/labels/neuron_ids`` of ``hdf_path``, with a
    voxel size of (40, 4, 4) nm along the (z, y, x) axes.
    """
    raw_path = f"{zarr_group}/raw"
    labels_path = f"{zarr_group}/labels"
    if Path(raw_path).exists() and Path(labels_path).exists():
        return open_ds(raw_path), open_ds(labels_path)

    arrays = []
    # Context manager guarantees the HDF5 handle is closed after copying.
    with h5py.File(hdf_path, "r") as hdf:
        for zarr_path, hdf_key in (
            (raw_path, "volumes/raw"),
            (labels_path, "volumes/labels/neuron_ids"),
        ):
            data = hdf[hdf_key][:]
            array = prepare_ds(
                zarr_path,
                data.shape,
                voxel_size=(40, 4, 4),
                dtype=data.dtype,
                axis_names=["z", "y", "x"],
                units=["nm", "nm", "nm"],
            )
            array[array.roi] = data
            arrays.append(array)
    return arrays[0], arrays[1]


test_raw, test_labels = _hdf_to_zarr("sample_C_20160501.hdf", "cremi.zarr/test")
train_raw, train_labels = _hdf_to_zarr("sample_A_20160501.hdf", "cremi.zarr/train")
Let's visualize our train and test data
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.animation as animation
from IPython.display import HTML
import matplotlib as mpl

# ``animation.embed_limit`` is measured in MB; raise it above the default so
# the base64-encoded videos below can be embedded in the notebook.
mpl.rcParams["animation.embed_limit"] = 50  # 50 MB

# Custom label color map for showing instances: entry 0 (background) is
# black, the remaining 255 entries are random colors. Matplotlib expects RGB
# components in [0, 1], so the random 0-255 integers are divided by 255.
np.random.seed(1)
colors = [[0, 0, 0]] + [
    list(np.random.choice(range(256), size=3) / 255) for _ in range(255)
]
label_cmap = ListedColormap(colors)
Training data
fig, axes = plt.subplots(1, 2, figsize=(12, 6))


def _train_slice(z):
    """Return the (raw, labels) numpy arrays for z-slice ``z`` of the training volume."""
    raw = np.asarray(train_raw.data[z])
    # Wrap instance ids into [0, 255] so they index the 256-entry label_cmap.
    labels = np.asarray(train_labels.data[z]) % 256
    return raw, labels


# One image per panel, updated in place each frame, instead of stacking one
# AxesImage per slice (which keeps 2 * depth full-size images in memory).
first_raw, first_labels = _train_slice(0)
raw_im = axes[0].imshow(first_raw)
axes[0].set_title("Raw Train Data")
label_im = axes[1].imshow(
    first_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none"
)
axes[1].set_title("Train Labels")


def _show_train_slice(z):
    """Update both panels to display z-slice ``z``; returns the changed artists."""
    raw, labels = _train_slice(z)
    raw_im.set_data(raw)
    # Rescale per slice, matching imshow's default autoscaling of each image.
    raw_im.set_clim(raw.min(), raw.max())
    label_im.set_data(labels)
    return raw_im, label_im


# Sweep forward through the stack and back again without repeating the end
# slices, so the looping video plays smoothly.
n_slices = train_raw.data.shape[0]
frames = list(range(n_slices)) + list(range(n_slices - 2, 0, -1))
ani = animation.FuncAnimation(
    fig, _show_train_slice, frames=frames, blit=True, repeat_delay=1000
)
video_html = ani.to_html5_video()
video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
# Close the figure so the notebook shows only the video, not an extra static plot.
plt.close(fig)
HTML(video_html)
WARNING:matplotlib.animation:MovieWriter stderr:
Received > 3 system signals, hard exiting
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:224, in AbstractMovieWriter.saving(self, fig, outfile, dpi, *args, **kwargs)
223 try:
--> 224 yield self
225 finally:
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1126, in Animation.save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs, progress_callback)
1125 frame_number += 1
-> 1126 writer.grab_frame(**savefig_kwargs)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:352, in MovieWriter.grab_frame(self, **savefig_kwargs)
351 # Save the figure data to the sink, using the frame format and dpi.
--> 352 self.fig.savefig(self._proc.stdin, format=self.frame_format,
353 dpi=self.dpi, **savefig_kwargs)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/figure.py:3490, in Figure.savefig(self, fname, transparent, **kwargs)
3489 _recursively_make_axes_transparent(stack, ax)
-> 3490 self.canvas.print_figure(fname, **kwargs)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backend_bases.py:2184, in FigureCanvasBase.print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)
2183 with cbook._setattr_cm(self.figure, dpi=dpi):
-> 2184 result = print_method(
2185 filename,
2186 facecolor=facecolor,
2187 edgecolor=edgecolor,
2188 orientation=orientation,
2189 bbox_inches_restore=_bbox_inches_restore,
2190 **kwargs)
2191 finally:
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backend_bases.py:2040, in FigureCanvasBase._switch_canvas_and_return_print_method.<locals>.<lambda>(*args, **kwargs)
2039 skip = optional_kws - {*inspect.signature(meth).parameters}
-> 2040 print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
2041 *args, **{k: v for k, v in kwargs.items() if k not in skip}))
2042 else: # Let third-parties do as they see fit.
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backends/backend_agg.py:417, in FigureCanvasAgg.print_raw(self, filename_or_obj, metadata)
416 raise ValueError("metadata not supported for raw/rgba")
--> 417 FigureCanvasAgg.draw(self)
418 renderer = self.get_renderer()
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backends/backend_agg.py:382, in FigureCanvasAgg.draw(self)
380 with (self.toolbar._wait_cursor_for_draw_cm() if self.toolbar
381 else nullcontext()):
--> 382 self.figure.draw(self.renderer)
383 # A GUI class may be need to update a window using this draw, so
384 # don't forget to call the superclass.
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:94, in _finalize_rasterization.<locals>.draw_wrapper(artist, renderer, *args, **kwargs)
92 @wraps(draw)
93 def draw_wrapper(artist, renderer, *args, **kwargs):
---> 94 result = draw(artist, renderer, *args, **kwargs)
95 if renderer._rasterizing:
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
69 renderer.start_filter()
---> 71 return draw(artist, renderer)
72 finally:
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/figure.py:3257, in Figure.draw(self, renderer)
3256 self.patch.draw(renderer)
-> 3257 mimage._draw_list_compositing_images(
3258 renderer, self, artists, self.suppressComposite)
3260 renderer.close_group('figure')
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:134, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
133 for a in artists:
--> 134 a.draw(renderer)
135 else:
136 # Composite any adjacent images together
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
69 renderer.start_filter()
---> 71 return draw(artist, renderer)
72 finally:
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/axes/_base.py:3210, in _AxesBase.draw(self, renderer)
3208 _draw_rasterized(self.get_figure(root=True), artists_rasterized, renderer)
-> 3210 mimage._draw_list_compositing_images(
3211 renderer, self, artists, self.get_figure(root=True).suppressComposite)
3213 renderer.close_group('axes')
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:134, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
133 for a in artists:
--> 134 a.draw(renderer)
135 else:
136 # Composite any adjacent images together
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
69 renderer.start_filter()
---> 71 return draw(artist, renderer)
72 finally:
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:609, in _ImageBase.draw(self, renderer)
608 else:
--> 609 im, l, b, trans = self.make_image(
610 renderer, renderer.get_image_magnification())
611 if im is not None:
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:912, in AxesImage.make_image(self, renderer, magnification, unsampled)
910 clip = ((self.get_clip_box() or self.axes.bbox) if self.get_clip_on()
911 else self.get_figure(root=True).bbox)
--> 912 return self._make_image(self._A, bbox, transformed_bbox, clip,
913 magnification, unsampled=unsampled)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:512, in _ImageBase._make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)
509 output_alpha = _resample( # resample alpha channel
510 self, A[..., 3], out_shape, t)
511 output = _resample( # resample rgb channels
--> 512 self, _rgb_to_rgba(A[..., :3]), out_shape, t)
513 elif np.ndim(alpha) > 0: # Array alpha
514 # user-specified array alpha overrides the existing alpha channel
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:600, in IPKernelApp.sigint_handler(self, *args)
599 elif self.kernel.shell_is_blocking:
--> 600 raise KeyboardInterrupt
KeyboardInterrupt:
During handling of the above exception, another exception occurred:
CalledProcessError Traceback (most recent call last)
Cell In[5], line 27
25 ims = ims + ims[::-1]
26 ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
---> 27 video_html = ani.to_html5_video()
28 video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
29 HTML(video_html)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1306, in Animation.to_html5_video(self, embed_limit)
1302 Writer = writers[mpl.rcParams['animation.writer']]
1303 writer = Writer(codec='h264',
1304 bitrate=mpl.rcParams['animation.bitrate'],
1305 fps=1000. / self._interval)
-> 1306 self.save(str(path), writer=writer)
1307 # Now open and base64 encode.
1308 vid64 = base64.encodebytes(path.read_bytes())
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1098, in Animation.save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs, progress_callback)
1093 return a * np.array([r, g, b]) + 1 - a
1095 # canvas._is_saving = True makes the draw_event animation-starting
1096 # callback a no-op; canvas.manager = None prevents resizing the GUI
1097 # widget (both are likewise done in savefig()).
-> 1098 with (writer.saving(self._fig, filename, dpi),
1099 cbook._setattr_cm(self._fig.canvas, _is_saving=True, manager=None)):
1100 if not writer._supports_transparency():
1101 facecolor = savefig_kwargs.get('facecolor',
1102 mpl.rcParams['savefig.facecolor'])
File ~/.local/share/uv/python/cpython-3.10.17-linux-x86_64-gnu/lib/python3.10/contextlib.py:153, in _GeneratorContextManager.__exit__(self, typ, value, traceback)
151 value = typ()
152 try:
--> 153 self.gen.throw(typ, value, traceback)
154 except StopIteration as exc:
155 # Suppress StopIteration *unless* it's the same exception that
156 # was passed to throw(). This prevents a StopIteration
157 # raised inside the "with" statement from being suppressed.
158 return exc is not value
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:226, in AbstractMovieWriter.saving(self, fig, outfile, dpi, *args, **kwargs)
224 yield self
225 finally:
--> 226 self.finish()
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:341, in MovieWriter.finish(self)
337 _log.log(
338 logging.WARNING if self._proc.returncode else logging.DEBUG,
339 "MovieWriter stderr:\n%s", err)
340 if self._proc.returncode:
--> 341 raise subprocess.CalledProcessError(
342 self._proc.returncode, self._proc.args, out, err)
CalledProcessError: Command '['ffmpeg', '-f', 'rawvideo', '-vcodec', 'rawvideo', '-s', '1200x600', '-pix_fmt', 'rgba', '-framerate', '5.0', '-loglevel', 'error', '-i', 'pipe:', '-vcodec', 'h264', '-pix_fmt', 'yuv420p', '-y', '/tmp/tmp72kydopw/temp.m4v']' returned non-zero exit status 123.
Testing data
fig, axes = plt.subplots(1, 2, figsize=(12, 6))


def _test_slice(z):
    """Return the (raw, labels) numpy arrays for z-slice ``z`` of the test volume."""
    raw = np.asarray(test_raw.data[z])
    # Wrap instance ids into [0, 255] so they index the 256-entry label_cmap.
    labels = np.asarray(test_labels.data[z]) % 256
    return raw, labels


# One image per panel, updated in place each frame, instead of stacking one
# AxesImage per slice (which keeps 2 * depth full-size images in memory).
first_raw, first_labels = _test_slice(0)
raw_im = axes[0].imshow(first_raw)
axes[0].set_title("Raw Test Data")
label_im = axes[1].imshow(
    first_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none"
)
axes[1].set_title("Test Labels")


def _show_test_slice(z):
    """Update both panels to display z-slice ``z``; returns the changed artists."""
    raw, labels = _test_slice(z)
    raw_im.set_data(raw)
    # Rescale per slice, matching imshow's default autoscaling of each image.
    raw_im.set_clim(raw.min(), raw.max())
    label_im.set_data(labels)
    return raw_im, label_im


# Sweep forward through the stack and back again without repeating the end
# slices, so the looping video plays smoothly.
n_slices = test_raw.data.shape[0]
frames = list(range(n_slices)) + list(range(n_slices - 2, 0, -1))
ani = animation.FuncAnimation(
    fig, _show_test_slice, frames=frames, blit=True, repeat_delay=1000
)
video_html = ani.to_html5_video()
video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
# Close the figure so the notebook shows only the video, not an extra static plot.
plt.close(fig)
HTML(video_html)